import sys
import time
import os
import os.path as osp
import requests
import shutil
import tqdm
import pickle
import numpy as np
import scipy.io as sio
import scipy.sparse as sp

import torch

from ..data import Data, Dataset, download_url

from . import register_dataset


def untar(path, fname, deleteTar=True):
    """Extract the archive ``path/fname`` into ``path``.

    The archive format is inferred by :func:`shutil.unpack_archive` from the
    file extension (zip, tar, tar.gz, ...).

    Args:
        path: Directory that holds the archive; its contents land here too.
        fname: File name of the archive inside ``path``.
        deleteTar: When true (the default), remove the archive once it has
            been extracted.
    """
    print("unpacking " + fname)
    archive = os.path.join(path, fname)
    shutil.unpack_archive(archive, path)
    if not deleteTar:
        return
    os.remove(archive)


def sample_mask(idx, l):
    """Create a boolean mask of length ``l`` that is True at positions ``idx``.

    Args:
        idx: Index, sequence of indices, or integer array selecting the
            positions to set; anything accepted by NumPy indexing on a 1-D
            array works.
        l: Length of the mask.

    Returns:
        np.ndarray: 1-D array of dtype ``bool`` with shape ``(l,)``.
    """
    # ``np.bool`` was only a deprecated alias of the builtin ``bool`` and was
    # removed in NumPy 1.24, so use the builtin directly. Allocating the mask
    # as bool up front also avoids a float buffer and an extra copy.
    mask = np.zeros(l, dtype=bool)
    mask[idx] = True
    return mask


class HANDataset(Dataset):
    r"""The network datasets "ACM", "DBLP" and "IMDB" from the
    `"Heterogeneous Graph Attention Network"
    <https://arxiv.org/abs/1903.07293>`_ paper.

    The processed graph is a single :class:`Data` object with:

    * ``x`` -- float32 node-feature matrix, one row per node.
    * ``adj`` -- list of ``(edge_index, edge_weight)`` pairs: one per
      adjacency matrix stored in ``data.mat`` (e.g. ``PAP``/``PLP`` for ACM),
      followed by a final identity (self-loop) relation. ``edge_index`` has
      shape ``(2, E)`` and every ``edge_weight`` entry is 1.
    * ``{train,valid,test}_node`` / ``{train,valid,test}_target`` -- node
      indices of each split and the matching class labels.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"han-acm"`,
            :obj:`"han-dblp"`, :obj:`"han-imdb"`).
    """

    # Keys inside each dataset's ``data.mat``: the feature-matrix key and the
    # keys of the adjacency matrices that become the relations in ``data.adj``.
    _MAT_KEYS = {
        "han-acm": ("feature", ("PAP", "PLP")),
        "han-dblp": ("features", ("net_APA", "net_APCPA", "net_APTPA")),
        "han-imdb": ("feature", ("MAM", "MDM", "MYM")),
    }

    def __init__(self, root, name):
        self.name = name
        self.url = (
            f"https://github.com/cenyk1230/han-data/blob/master/{name}.zip?raw=true"
        )
        # Set before the base-class constructor, which presumably triggers
        # download()/process() (both read self.name / self.url).
        super().__init__(root)
        self.data = self._load_processed(self.processed_paths[0])
        self.num_classes = torch.max(self.data.train_target).item() + 1
        # Number of relations, including the trailing identity relation.
        self.num_edge = len(self.data.adj)
        self.num_nodes = self.data.x.shape[0]

    @staticmethod
    def _load_processed(path):
        """Load the ``Data`` object that :meth:`process` saved at ``path``."""
        # PyTorch >= 2.6 defaults to weights_only=True, which refuses to
        # unpickle arbitrary objects such as ``Data``. The file is written
        # locally by process(); PyTorch < 1.13 has no weights_only keyword.
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            return torch.load(path)

    @property
    def raw_file_names(self):
        return ["data.mat"]

    @property
    def processed_file_names(self):
        return ["data.pt"]

    @staticmethod
    def _split(idx, labels):
        """Return ``(node_index, target)`` LongTensors for one data split.

        ``idx`` is a split-index array as read from ``data.mat``. It is
        flattened to 1-D, and labels are gathered by index rather than
        through a boolean mask. A mask returns rows in ascending order and
        drops duplicates, so targets would misalign with the node indices
        whenever a split is not sorted and unique.
        """
        nodes = np.asarray(idx).reshape(-1)
        # ``labels`` is indexed per node; argmax over columns gives the class.
        targets = np.argmax(labels[nodes], axis=1)
        return torch.from_numpy(nodes).long(), torch.from_numpy(targets).long()

    def read_gtn_data(self, folder):
        """Read ``folder/data.mat`` and store the resulting graph in ``self.data``.

        Raises:
            ValueError: If ``self.name`` is not a supported HAN dataset.
        """
        if self.name not in self._MAT_KEYS:
            raise ValueError(
                f"Unknown HAN dataset {self.name!r}; "
                f"expected one of {sorted(self._MAT_KEYS)}"
            )
        feature_key, adj_keys = self._MAT_KEYS[self.name]

        mat = sio.loadmat(osp.join(folder, "data.mat"))
        labels = mat["label"]
        features = mat[feature_key].astype(float)
        num_nodes = features.shape[0]

        data = Data()
        # Subtract the identity from each stored adjacency matrix (presumably
        # to drop its self-loops, since one explicit identity relation is
        # appended below). Every remaining nonzero entry becomes an edge of
        # weight 1. The identity is allocated once and each dense difference
        # is discarded as soon as its edges have been extracted.
        eye = np.eye(num_nodes)
        adj = []
        for key in adj_keys:
            rows, cols = (mat[key] - eye).nonzero()
            edge_index = torch.from_numpy(np.vstack((rows, cols))).long()
            weight = torch.ones(edge_index.shape[1], dtype=torch.float32)
            adj.append((edge_index, weight))
        loops = torch.arange(num_nodes, dtype=torch.long)
        adj.append(
            (torch.stack((loops, loops)), torch.ones(num_nodes, dtype=torch.float32))
        )
        data.adj = adj

        data.x = torch.from_numpy(features).float()

        data.train_node, data.train_target = self._split(mat["train_idx"], labels)
        data.valid_node, data.valid_target = self._split(mat["val_idx"], labels)
        data.test_node, data.test_target = self._split(mat["test_idx"], labels)

        self.data = data

    def get(self, idx):
        """Return the single graph held by this dataset (only ``idx == 0``).

        Raises:
            IndexError: If ``idx`` is not 0.
        """
        # Explicit check instead of ``assert``, which is stripped under ``-O``.
        if idx != 0:
            raise IndexError(f"{self.name} holds a single graph; got index {idx}")
        return self.data

    def apply_to_device(self, device):
        """Move every tensor of ``self.data`` (including ``adj``) to ``device``."""
        for key in (
            "x",
            "train_node",
            "valid_node",
            "test_node",
            "train_target",
            "valid_target",
            "test_target",
        ):
            setattr(self.data, key, getattr(self.data, key).to(device))
        self.data.adj = [
            (edge_index.to(device), weight.to(device))
            for edge_index, weight in self.data.adj
        ]

    def download(self):
        """Fetch ``<name>.zip`` into ``raw_dir`` and extract it there."""
        download_url(self.url, self.raw_dir, name=self.name + ".zip")
        untar(self.raw_dir, self.name + ".zip")

    def process(self):
        """Build the graph from the raw files and cache it to ``processed_paths[0]``."""
        self.read_gtn_data(self.raw_dir)
        torch.save(self.data, self.processed_paths[0])

    def __repr__(self):
        return "{}()".format(self.name)


@register_dataset("han-acm")
class ACM_HANDataset(HANDataset):
    """ACM graph from the HAN benchmark, registered under ``"han-acm"``.

    Args:
        path: Root directory where the dataset is downloaded and cached.
    """

    def __init__(self, path):
        super().__init__(path, "han-acm")


@register_dataset("han-dblp")
class DBLP_HANDataset(HANDataset):
    """DBLP graph from the HAN benchmark, registered under ``"han-dblp"``.

    Args:
        path: Root directory where the dataset is downloaded and cached.
    """

    def __init__(self, path):
        super().__init__(path, "han-dblp")


@register_dataset("han-imdb")
class IMDB_HANDataset(HANDataset):
    """IMDB graph from the HAN benchmark, registered under ``"han-imdb"``.

    Args:
        path: Root directory where the dataset is downloaded and cached.
    """

    def __init__(self, path):
        super().__init__(path, "han-imdb")
